{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Checkpoints and output location ----\n",
    "stage_1_path = \"./checkpoints/stage1.pt\"  # Stage-I (low-resolution) checkpoint\n",
    "stage_2_path = \"./checkpoints/stage2.pt\"  # Stage-II (high-resolution) checkpoint\n",
    "save_dir = \"vis_270p_1080p\"  # folder the rendered videos are written to\n",
    "# Resolutions are set in the generation cell (stage_1_hw / stage_2_hw).\n",
    "\n",
    "# ---- Stage-I sampling ----\n",
    "# TODO: tune the Stage-I CFG scale.\n",
    "cfg_first = 8\n",
    "\n",
    "# ---- Stage-II sampling (author's suggested range in brackets) ----\n",
    "shift_t = 2.5  # [2, 3]\n",
    "sample_step = 5  # [4, 6]\n",
    "cfg_second = 13  # [10, 13]\n",
    "deg_latent_strength = 675  # [650, 750]; used as the Stage-II ref_noise_step\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Work from the parent of the kernel's starting directory (normally the\n",
    "# notebook's folder). The starting directory is remembered on the first run\n",
    "# so re-executing this cell does not climb one more level each time.\n",
    "if \"notebook_dir\" not in globals():\n",
    "    notebook_dir = os.getcwd()\n",
    "current_directory = notebook_dir\n",
    "os.chdir(os.path.dirname(current_directory))\n",
    "new_directory = os.getcwd()\n",
    "print(f\"working directory: {new_directory}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import os\n",
    "import argparse\n",
    "import torch\n",
    "import numpy as np\n",
    "import copy\n",
    "\n",
    "from sat.model.base_model import get_model\n",
    "from arguments import get_args\n",
    "from torchvision.io.video import write_video\n",
    "\n",
    "from flow_video import FlowEngine\n",
    "from diffusion_video import SATVideoDiffusionEngine\n",
    "\n",
    "import os\n",
    "from utils import disable_all_init, decode, prepare_input, save_memory_encode_first_stage, save_mem_decode, seed_everything\n",
    "# Project helper imported from utils; presumably disables default parameter\n",
    "# initialisation because checkpoints are loaded afterwards -- TODO confirm.\n",
    "disable_all_init()\n",
    "\n",
    "\n",
    "def _attach_share_cache(model, model_args):\n",
    "    \"\"\"Give every submodule of `model` one shared dict and register extra modules.\n",
    "\n",
    "    The dict is a fresh copy of `model_args.share_cache_config` (empty when the\n",
    "    args have no such attribute). After the dict is attached, every module that\n",
    "    defines `register_new_modules` has it called.\n",
    "    \"\"\"\n",
    "    share_cache = dict()\n",
    "    if hasattr(model_args, 'share_cache_config'):\n",
    "        for k, v in model_args.share_cache_config.items():\n",
    "            share_cache[k] = v\n",
    "    for _, module in model.named_modules():\n",
    "        module.share_cache = share_cache\n",
    "        if hasattr(module, \"register_new_modules\"):\n",
    "            module.register_new_modules()\n",
    "\n",
    "\n",
    "# Position-embedding tables (freqs_sin / freqs_cos) that are never taken from a\n",
    "# checkpoint, so the model keeps the tables it built itself.\n",
    "_SKIPPED_CKPT_KEYS = (\n",
    "    \"model.diffusion_model.mixins.pos_embed.freqs_sin\",\n",
    "    \"model.diffusion_model.mixins.pos_embed.freqs_cos\",\n",
    ")\n",
    "\n",
    "\n",
    "def _load_checkpoint(model, weight_path):\n",
    "    \"\"\"Load the `\"module\"` state dict saved at `weight_path` into `model`.\n",
    "\n",
    "    Loading is non-strict and the resulting missing/unexpected-key report is\n",
    "    printed. The keys in `_SKIPPED_CKPT_KEYS` are dropped first, each one\n",
    "    independently, so a checkpoint holding only one of them still loads.\n",
    "    \"\"\"\n",
    "    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.\n",
    "    weight = torch.load(weight_path, map_location=\"cpu\")\n",
    "    state_dict = weight[\"module\"]\n",
    "    for key in _SKIPPED_CKPT_KEYS:\n",
    "        state_dict.pop(key, None)\n",
    "    msg = model.load_state_dict(state_dict, strict=False)\n",
    "    print(msg)\n",
    "\n",
    "\n",
    "def init_model(model, second_model, args, second_args):\n",
    "    \"\"\"Prepare both stage models for inference and load their checkpoints.\n",
    "\n",
    "    Args:\n",
    "        model (nn.Module): Stage-I model, loaded from `args.inf_ckpt`.\n",
    "        second_model (nn.Module): Stage-II model, loaded from `args.inf_ckpt2`.\n",
    "        args: Stage-I arguments (also provide the Stage-II checkpoint path).\n",
    "        second_args: Stage-II arguments (only its share_cache_config is used).\n",
    "    \"\"\"\n",
    "    _attach_share_cache(model, args)\n",
    "    _attach_share_cache(second_model, second_args)\n",
    "    # One checkpoint per helper call, so the first CPU copy can be freed\n",
    "    # before the second one is read.\n",
    "    _load_checkpoint(model, args.inf_ckpt)\n",
    "    _load_checkpoint(second_model, args.inf_ckpt2)\n",
    "\n",
    "def get_first_results(model, text, num_frames, H, W, neg_prompt=None):\n",
    "    \"\"\"Get first Stage results.\n",
    "\n",
    "    The diffusion model is moved to the GPU for sampling and offloaded to the\n",
    "    CPU before decoding; only its `first_stage_model` is moved back to the GPU\n",
    "    for `decode`. The model is left on the CPU on return, including when an\n",
    "    exception is raised, so the GPU is free for Stage II.\n",
    "\n",
    "    Args:\n",
    "        model (nn.Module): first stage model.\n",
    "        text (str): text prompt\n",
    "        num_frames (int): number of frames; the latent has\n",
    "            1 + (num_frames - 1) // 4 frames.\n",
    "        H (int): height of the first stage results\n",
    "        W (int): width of the first stage results\n",
    "        neg_prompt (str): negative prompt, appended to a fixed low-motion\n",
    "            prefix. Defaults to no extra text.\n",
    "\n",
    "    Returns:\n",
    "        Tensor: first stage video as produced by `decode` (presumably\n",
    "        (frames, H, W, 3) with values in [0, 255], which is what\n",
    "        get_second_results expects -- TODO confirm).\n",
    "    \"\"\"\n",
    "    device = 'cuda'\n",
    "    # Latent frame count for `num_frames` pixel frames, and the spatial\n",
    "    # downsampling factor between pixels and latents.\n",
    "    T = 1 + (num_frames - 1) // 4\n",
    "    F = 8\n",
    "    motion_text_prefix = [\n",
    "        'very low motion,',\n",
    "        'low motion,',\n",
    "        'medium motion,',\n",
    "        'high motion,',\n",
    "        'very high motion,',\n",
    "    ]\n",
    "    pos_prompt = \"\"\n",
    "    if neg_prompt is None:\n",
    "        neg_prompt = \"\"\n",
    "    # NOTE(review): the prefixes already end in ',' and nothing separates them\n",
    "    # from `neg_prompt`, giving \"very low motion,, low motion,{neg_prompt}\".\n",
    "    # Kept verbatim because changing the prompt changes the generated video.\n",
    "    input_negative_prompt = motion_text_prefix[\n",
    "        0] + ', ' + motion_text_prefix[1] + neg_prompt\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            model.to(device)\n",
    "            c, uc = prepare_input(text,\n",
    "                                  model,\n",
    "                                  T,\n",
    "                                  negative_prompt=input_negative_prompt,\n",
    "                                  pos_prompt=pos_prompt)\n",
    "            with torch.amp.autocast(enabled=True,\n",
    "                                    device_type='cuda',\n",
    "                                    dtype=torch.bfloat16):\n",
    "                samples_z = model.sample(\n",
    "                    c,\n",
    "                    uc=uc,\n",
    "                    batch_size=1,\n",
    "                    shape=(T, 16, H // F, W // F),\n",
    "                    num_steps=model.share_cache.get('first_sample_step', None),\n",
    "                )\n",
    "            # Swap axes 1 and 2 of the sampled latent before decoding.\n",
    "            samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()\n",
    "\n",
    "            # Offload the diffusion model, then bring back only the autoencoder.\n",
    "            model.to('cpu')\n",
    "            torch.cuda.empty_cache()\n",
    "            first_stage_model = model.first_stage_model.to(device)\n",
    "\n",
    "            latent = 1.0 / model.scale_factor * samples_z\n",
    "            samples = decode(first_stage_model, latent)\n",
    "    finally:\n",
    "        model.to('cpu')\n",
    "    return samples\n",
    "\n",
    "\n",
    "def get_second_results(model, text, first_stage_samples, num_frames):\n",
    "    \"\"\"Get second Stage results.\n",
    "\n",
    "    The Stage-I frames are rescaled from [0, 255] to [-1, 1], resized to the\n",
    "    Stage-II resolution, encoded with the model's `first_stage_model` and passed\n",
    "    to `model.sample` as the reference latent. The result is decoded with the\n",
    "    slice-wise `save_mem_decode`. The model is left on the CPU on return,\n",
    "    including when an exception is raised.\n",
    "\n",
    "    The output size is `share_cache['target_size']` when set, otherwise the\n",
    "    Stage-I size times `share_cache['upscale_factor']` (default 8); either way\n",
    "    both sides are rounded down to a multiple of 16.\n",
    "\n",
    "    Args:\n",
    "        model (nn.Module): second stage model.\n",
    "        text (str): text prompt\n",
    "        first_stage_samples (Tensor): first stage results, laid out as\n",
    "            (frames, height, width, channels) with values in [0, 255].\n",
    "        num_frames (int): number of frames to refine; clamped to the number\n",
    "            of frames available in `first_stage_samples`.\n",
    "    Returns:\n",
    "        Tensor: second stage results.\n",
    "    \"\"\"\n",
    "    device = 'cuda'\n",
    "    t, h, w, _ = first_stage_samples.shape\n",
    "    # Requesting more frames than Stage I produced would make the latent shape\n",
    "    # passed to `model.sample` disagree with the encoded reference.\n",
    "    num_frames = min(num_frames, t)\n",
    "    frames = first_stage_samples[:num_frames]\n",
    "    frames = (frames / 255. - 0.5) / 0.5\n",
    "\n",
    "    target_size = model.share_cache.get('target_size', None)\n",
    "    if target_size is None:\n",
    "        upscale_factor = model.share_cache.get('upscale_factor', 8)\n",
    "        H = int(h * upscale_factor) // 16 * 16\n",
    "        W = int(w * upscale_factor) // 16 * 16\n",
    "    else:\n",
    "        H, W = target_size\n",
    "        H = H // 16 * 16\n",
    "        W = W // 16 * 16\n",
    "\n",
    "    # (T, H, W, C) -> (T, C, H, W), the layout `interpolate` resizes.\n",
    "    frames = frames.permute(0, 3, 1, 2).to(device)\n",
    "    ref_x = torch.nn.functional.interpolate(frames,\n",
    "                                            size=(H, W),\n",
    "                                            mode='bilinear',\n",
    "                                            align_corners=False,\n",
    "                                            antialias=True)\n",
    "    # Add a batch axis and swap axes 1 and 2: (1, C, T, H, W) for the encoder.\n",
    "    ref_x = ref_x[None].permute(0, 2, 1, 3, 4).contiguous()\n",
    "\n",
    "    try:\n",
    "        print('start encoding first stage results to high resolution')\n",
    "        with torch.no_grad():\n",
    "            first_stage_dtype = next(model.first_stage_model.parameters()).dtype\n",
    "            model.first_stage_model.cuda()\n",
    "            ref_x = save_memory_encode_first_stage(\n",
    "                ref_x.contiguous().to(first_stage_dtype).cuda(), model)\n",
    "        ref_x = ref_x.permute(0, 2, 1, 3, 4).contiguous().to(model.dtype)\n",
    "        print('finish encoding first stage results, and starting stage II')\n",
    "\n",
    "        model.to(device)\n",
    "\n",
    "        pos_prompt = ''\n",
    "        input_negative_prompt = \"\"\n",
    "        T = 1 + (num_frames - 1) // 4\n",
    "        F = 8\n",
    "        with torch.no_grad():\n",
    "            # Built under no_grad, as in get_first_results, so no autograd graph\n",
    "            # is recorded for the text conditioning.\n",
    "            # NOTE(review): Stage I passes the latent frame count T here while\n",
    "            # Stage II passes num_frames -- confirm which prepare_input expects.\n",
    "            c, uc = prepare_input(text,\n",
    "                                  model,\n",
    "                                  num_frames,\n",
    "                                  negative_prompt=input_negative_prompt,\n",
    "                                  pos_prompt=pos_prompt)\n",
    "            with torch.amp.autocast(enabled=True,\n",
    "                                    device_type='cuda',\n",
    "                                    dtype=torch.bfloat16):\n",
    "                samples_z = model.sample(\n",
    "                    ref_x,\n",
    "                    c,\n",
    "                    uc=uc,\n",
    "                    batch_size=1,\n",
    "                    shape=(T, 16, H // F, W // F),\n",
    "                    num_steps=model.share_cache.get('sample_step', 5),\n",
    "                    method='euler',\n",
    "                    cfg=model.share_cache.get('cfg', 7.5),\n",
    "                )\n",
    "                samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()\n",
    "\n",
    "                # Offload the diffusion model, then bring back only the autoencoder.\n",
    "                model.to('cpu')\n",
    "                torch.cuda.empty_cache()\n",
    "                first_stage_model = model.first_stage_model.to(device)\n",
    "\n",
    "                latent = 1.0 / model.scale_factor * samples_z\n",
    "                print('start spatiotemporal slice decoding')\n",
    "                samples = save_mem_decode(first_stage_model, latent)\n",
    "                print('finish spatiotemporal slice decoding')\n",
    "    finally:\n",
    "        model.to('cpu')\n",
    "    return samples\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pretend to be rank 0 of a single-process torch.distributed job.\n",
    "os.environ[\"LOCAL_RANK\"] = \"0\"\n",
    "os.environ[\"WORLD_SIZE\"] = \"1\"\n",
    "os.environ[\"RANK\"] = \"0\"\n",
    "os.environ[\"MASTER_ADDR\"] = \"0.0.0.0\"\n",
    "os.environ[\"MASTER_PORT\"] = \"12345\"\n",
    "\n",
    "\n",
    "def _prepare_inference_args(arg_list, known):\n",
    "    \"\"\"Parse `arg_list` with `get_args` and patch the config for 1-GPU inference.\n",
    "\n",
    "    `known` is merged into the parsed namespace, `deepspeed_config` is removed,\n",
    "    `cp_size` and `model_parallel_size` are set to 1, and\n",
    "    `checkpoint_activations` and `uniform_sampling` are set to False.\n",
    "    \"\"\"\n",
    "    parsed = get_args(arg_list)\n",
    "    parsed = argparse.Namespace(**vars(parsed), **vars(known))\n",
    "    del parsed.deepspeed_config\n",
    "    parsed.model_config.first_stage_config.params.cp_size = 1\n",
    "    parsed.model_config.network_config.params.transformer_args.model_parallel_size = 1\n",
    "    parsed.model_config.network_config.params.transformer_args.checkpoint_activations = False\n",
    "    parsed.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False\n",
    "    return parsed\n",
    "\n",
    "\n",
    "# `py_parser` declares no options, so every flag below is returned in\n",
    "# `args_list` unchanged and `known` is an empty Namespace.\n",
    "py_parser = argparse.ArgumentParser(add_help=False)\n",
    "args_list = [\n",
    "    \"--base\", \"flashvideo/configs/stage1.yaml\",\n",
    "    \"--second\", \"flashvideo/configs/stage2.yaml\",\n",
    "    \"--inf-ckpt\", stage_1_path,\n",
    "    \"--inf-ckpt2\", stage_2_path,\n",
    "]\n",
    "known, args_list = py_parser.parse_known_args(args=args_list)\n",
    "second_args_list = copy.deepcopy(args_list)\n",
    "\n",
    "args = _prepare_inference_args(args_list, known)\n",
    "\n",
    "# Stage II uses the same flags with the --base config (index 1) replaced by\n",
    "# the --second config.\n",
    "second_args_list[1] = args.second[0]\n",
    "second_args = _prepare_inference_args(second_args_list, known)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_cls = SATVideoDiffusionEngine\n",
    "second_model_cls = FlowEngine\n",
    "local_rank = int(os.environ.get(\"LOCAL_RANK\", 0))\n",
    "torch.cuda.set_device(local_rank)\n",
    "\n",
    "# Build both stages (Stage II first).\n",
    "second_model = get_model(second_args, second_model_cls)\n",
    "model = get_model(args, model_cls)\n",
    "\n",
    "# Attach the share caches and load both checkpoints.\n",
    "init_model(model, second_model, args, second_args)\n",
    "\n",
    "model.eval()\n",
    "# The trailing ';' stops Jupyter from rendering the whole module tree as output.\n",
    "second_model.eval();\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call merge_lora() on every Stage-I submodule that provides it (presumably\n",
    "# folding LoRA weights into the base weights). second_model is not touched.\n",
    "for module_name, module in model.named_modules():\n",
    "    if not hasattr(module, \"merge_lora\"):\n",
    "        continue\n",
    "    module.merge_lora()\n",
    "    print(f\"merge lora of {module_name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_frames = 49\n",
    "second_num_frames = 49\n",
    "\n",
    "stage_1_hw = (270, 480)\n",
    "stage_2_hw = (1080, 1920)\n",
    "\n",
    "# Round each side down to a multiple of 16 (270 -> 256 and 1080 -> 1072).\n",
    "stage_1_hw = tuple(side // 16 * 16 for side in stage_1_hw)\n",
    "stage_2_hw = tuple(side // 16 * 16 for side in stage_2_hw)\n",
    "H, W = stage_1_hw\n",
    "\n",
    "seed_everything(0)\n",
    "\n",
    "text = \" Sunny day, The camera smoothly pushes in through an ornate garden archway, delicately adorned with climbing ivy. \\\n",
    "    Beyond the archway, a secret, tranquil garden is revealed, brimming with a vibrant array of blooming flowers \\\n",
    "        in a myriad of colors. A beautiful young woman with long wavy brown hair, she is smile to the camera , \\\n",
    "    wearing a red hat sits  holding a dog , the red hat has rich fabric texture  \\\n",
    "        wearing black pleated skirt and yellow sweater \"\n",
    "\n",
    "neg_text = \"\"\n",
    "\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "enu_index = \"1\"\n",
    "model.share_cache[\"cfg\"] = cfg_first\n",
    "\n",
    "first_stage_samples = get_first_results(model, text, num_frames, H, W, neg_text)\n",
    "\n",
    "first_stage_name = f\"{enu_index}_num_frame_{num_frames}.mp4\"\n",
    "print(f\"save to {save_dir}/{first_stage_name}\")\n",
    "# torchvision's write_video converts the tensor to numpy, so it must be on the CPU.\n",
    "write_video(filename=f'./{save_dir}/{first_stage_name}',\n",
    "            fps=8,\n",
    "            video_array=first_stage_samples.cpu(),\n",
    "            options={'crf': '14'})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "second_num_frames = 49\n",
    "\n",
    "# Stage-II settings are passed through the model's share_cache.\n",
    "second_model.share_cache[\"ref_noise_step\"] = deg_latent_strength\n",
    "second_model.share_cache[\"sample_ref_noise_step\"] = deg_latent_strength\n",
    "second_model.share_cache.pop(\"ref_noise_step_range\", None)\n",
    "second_model.share_cache[\"target_size\"] = stage_2_hw\n",
    "second_model.share_cache[\"shift_t\"] = shift_t\n",
    "second_model.share_cache[\"sample_step\"] = sample_step\n",
    "second_model.share_cache[\"cfg\"] = cfg_second\n",
    "post_fix = f'''noise_{second_model.share_cache[\"ref_noise_step\"]}_step_{second_model.share_cache[\"sample_step\"]}_cfg_{second_model.share_cache[\"cfg\"]}_shift_{second_model.share_cache[\"shift_t\"]}_size_{stage_2_hw[0]}x{stage_2_hw[1]}'''\n",
    "second_model.share_cache[\"time_size_embedding\"] = True\n",
    "second_stage_samples = get_second_results(second_model,\n",
    "                                          text,\n",
    "                                          first_stage_samples,\n",
    "                                          second_num_frames)\n",
    "\n",
    "video_stem = f\"{enu_index}_num_frame_{num_frames}_{post_fix}\"\n",
    "# Log the path that is actually written (it ends in '_second.mp4').\n",
    "print(f\"save to {save_dir}/{video_stem}_second.mp4\")\n",
    "write_video(filename=f'./{save_dir}/{video_stem}_second.mp4',\n",
    "            fps=8,\n",
    "            video_array=second_stage_samples.cpu(),\n",
    "            options={'crf': '14'})\n",
    "\n",
    "# Side-by-side video: Stage I resized to the Stage-II size (left) next to\n",
    "# Stage II (right). Frames are laid out as (T, H, W, C).\n",
    "part_first_stage = first_stage_samples[:second_num_frames]\n",
    "target_h, target_w = second_stage_samples.shape[1], second_stage_samples.shape[2]\n",
    "part_first_stage = torch.nn.functional.interpolate(part_first_stage.permute(0, 3, 1, 2).contiguous(),\n",
    "                                                   size=(target_h, target_w),\n",
    "                                                   mode=\"bilinear\",\n",
    "                                                   align_corners=False,\n",
    "                                                   antialias=True)\n",
    "part_first_stage = part_first_stage.permute(0, 2, 3, 1).contiguous()\n",
    "\n",
    "# dim=-2 is the width axis, so the two videos sit next to each other.\n",
    "joint_video = torch.cat([part_first_stage.cpu(), second_stage_samples.cpu()], dim=-2)\n",
    "print(f'./{save_dir}/{video_stem}_joint.mp4')\n",
    "write_video(filename=f'./{save_dir}/{video_stem}_joint.mp4',\n",
    "            fps=8,\n",
    "            video_array=joint_video,\n",
    "            options={'crf': '15'})\n"
   ]
  }
 ],
 "metadata": {
  "fileId": "c6eed2be-3101-492e-a984-783ecbc70a34",
  "filePath": "/mnt/bn/foundation-ads/shilong/conda/code/cogvideo-5b/sat/demo_ab.ipynb",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
